import os
from typing import List

import torch
import torchvision.transforms as T
import torchvision.utils
from torch import Tensor


def imshow(images: Tensor, mode: str = 'RGB'):
    """ show images

    Each image of the batch is converted to a PIL image, converted to
    ``mode``, and displayed with ``PIL.Image.Image.show``.

    :param images: images, shape is (B, C, H, W); floating-point images are
        expected to lie in [0, 1] (values outside are clamped)
    :param mode: the mode of images, e.g., 'L', 'RGB'
    """
    to_pil_image = T.ToPILImage()
    # Detach from autograd and move to CPU once for the whole batch,
    # mirroring what imsave does before writing.
    images_cpu = images.detach().cpu()
    for image in images_cpu:
        if image.is_floating_point():
            # ToPILImage scales float tensors by 255 without rounding or
            # clamping. Round and clamp the same way
            # torchvision.utils.save_image does, so imshow and imsave render
            # the same tensor identically.
            image = image.mul(255).add_(0.5).clamp_(0, 255).to(torch.uint8)
        pil_image = to_pil_image(image).convert(mode)
        pil_image.show()


def imsave(images: Tensor, filepath: str, filenames: List[str]):
    """ save each image of a batch to its own file

    :param images: images, shape is (B, C, H, W)
    :param filepath: directory the files are written into (it is not created
        here)
    :param filenames: one file name per image, in batch order; each file is
        written to ``os.path.join(filepath, filename)``
    :raises ValueError: if the number of filenames differs from the number
        of images
    """
    num_images = len(images)
    if len(filenames) != num_images:
        # Fail before writing anything rather than part-way through the batch.
        raise ValueError(
            f'expected {num_images} filenames, got {len(filenames)}')
    # Work on a detached CPU copy so the caller's tensor and autograd graph
    # are left untouched.
    images_cpu = images.clone().detach().cpu()
    for image, filename in zip(images_cpu, filenames):
        # Re-add a batch dimension of 1 so save_image writes exactly this image.
        torchvision.utils.save_image(image.unsqueeze(0),
                                     os.path.join(filepath, filename))
